import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import gym
from scipy.optimize import linear_sum_assignment  # 匈牙利算法
import torch.nn.functional as F
import numpy as np
from stable_baselines3 import PPO
import torch.optim as optim
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from torch.distributions import Categorical
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.mixture import GaussianMixture
import torchvision.transforms as transforms
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
import random, copy
from stable_baselines3.common.vec_env import DummyVecEnv


# Default compute device. "cuda" (no trailing colon) selects the current CUDA
# device; the former "cuda:" is rejected by torch.device as an invalid string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def ppo_state(env, total_timesteps=10000, n_collect=10000):
    """Train a PPO agent on ``env`` and record the observations it visits.

    A ``CnnPolicy`` PPO model is trained for ``total_timesteps`` steps. It is
    then rolled out deterministically for ``n_collect`` steps, and every
    observation returned by ``env.step`` is stored.

    NOTE(review): the ``[0]`` indexing after ``env.reset()`` assumes ``env`` is
    a single-environment vectorized env (e.g. wrapped in ``DummyVecEnv``).
    Such an env's ``reset()`` returns a batch of shape
    ``(1, *observation_shape)``, and its ``step()`` returns the 4-tuple
    ``(obs, reward, done, info)``. TODO confirm against the caller.

    :param env: environment used both for training and for the rollout.
    :param total_timesteps: number of PPO training timesteps (default 10000).
    :param n_collect: number of rollout steps to record (default 10000).
    :return: NumPy array stacking the ``n_collect`` collected observations.
    """
    # Debug print of the initial observation.
    obs = env.reset()
    print(f"Reset observation: {obs}")

    model = PPO("CnnPolicy", env, verbose=1, learning_rate=2.5e-4, n_steps=128,
                batch_size=64, n_epochs=4, gamma=0.99)
    model.learn(total_timesteps=total_timesteps)

    state_buffer = []
    obs = env.reset()[0]  # drop the env-batch dimension of the single env
    for _ in range(n_collect):
        action, _ = model.predict(obs, deterministic=True)
        obs, _, done, _ = env.step(action)
        state_buffer.append(obs)
        if done:
            obs = env.reset()[0]  # start a new episode, again unbatched

    state_buffer = np.array(state_buffer)

    print(f"Collected {len(state_buffer)} states.")
    print(state_buffer.shape)
    return state_buffer


class AtariResNetFeatureExtractor(nn.Module):
    """Frozen ImageNet-pretrained ResNet used as an image feature extractor.

    Each frame is converted to a PIL image and then to a 3-channel grayscale
    image. It is resized, normalized with ImageNet statistics, and passed
    through the ResNet backbone with its final fully-connected layer removed.

    :param output_dim: output size of ``self.fc``. NOTE(review): ``self.fc`` is
        never applied in ``forward``, so this does not affect the returned
        features.
    :param resnet_type: ``"resnet18"`` or ``"resnet34"``.
    :raises ValueError: for any other ``resnet_type``.
    """

    def __init__(self, output_dim=512, resnet_type="resnet18"):
        super().__init__()

        if resnet_type == "resnet18":
            resnet = models.resnet18(pretrained=True)
        elif resnet_type == "resnet34":
            resnet = models.resnet34(pretrained=True)
        else:
            raise ValueError("Unsupported ResNet type. Choose 'resnet18' or 'resnet34'.")

        self.transform = transforms.Compose([
            # Replicate the grey level into 3 channels so the unmodified
            # ImageNet conv1 (3 input channels) can be used.
            transforms.Grayscale(num_output_channels=3),
            # An int size scales the SHORTER side to 224 and keeps the aspect
            # ratio; square frames (e.g. 84x84) become 224x224.
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Freeze the pretrained weights: the backbone is used for inference only.
        for param in resnet.parameters():
            param.requires_grad = False

        # Drop the final classification layer; keep the conv stages and pooling.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # NOTE(review): unused in forward(), and its 384 input size does not
        # match the 512-dim backbone features. Kept so state_dict keys/shapes
        # stay unchanged.
        self.fc = nn.Linear(384, output_dim)

    def forward(self, x):
        """Extract backbone features for a batch of frames.

        :param x: iterable of frames, each accepted by
            ``torchvision.transforms.functional.to_pil_image``.
        :return: tensor of shape ``(B, 512)`` on the backbone's device.
        """
        images = torch.stack([
            self.transform(torchvision.transforms.functional.to_pil_image(img))
            for img in x
        ])
        # Follow wherever the backbone lives rather than a module-global device,
        # so inputs and weights always end up on the same device.
        images = images.to(next(self.resnet.parameters()).device)
        features = self.resnet(images)  # (B, 512, 1, 1)
        return torch.flatten(features, 1)  # (B, 512)


def extract_features_in_batches(states_tensor, batch_size, feature_extractor, device):
    """
    Extract features batch by batch and concatenate them.

    :param states_tensor: input states batched along dim 0, e.g. shape (500, 3, 84, 84)
    :param batch_size: number of states per forward pass
    :param feature_extractor: module mapping a batch of states to per-sample features
    :param device: device each batch is moved to before the forward pass
    :return: per-batch features concatenated along dim 0; with the ResNet
        extractor in this file each batch yields (B, 512), so the result is (N, 512)
    """
    # Switch to eval mode (affects layers such as BatchNorm/Dropout); this
    # mode change persists on the passed-in module.
    feature_extractor.eval()

    feature_list = []
    with torch.no_grad():  # no autograd graph -> lower memory use
        # Split on the tensor's original device and move one batch at a time,
        # so only a single batch has to reside on ``device``.
        for batch in torch.split(states_tensor, batch_size):
            feature_list.append(feature_extractor(batch.to(device)))

    return torch.cat(feature_list, dim=0)




